#!/usr/bin/env python3

"""
Generate Rust SSL error mapping code from OpenSSL sources.

This is based on CPython's Tools/ssl/make_ssl_data.py but generates
Rust code instead of C headers.

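It takes two arguments:
- the path to the OpenSSL source tree (e.g. git checkout)
- the path to the Rust file to be generated (e.g. stdlib/src/ssl/ssl_data.rs);
  this argument is optional, and the generated code is printed to stdout
  when it is omitted

Note that error codes are version specific.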
"""

import argparse
import datetime
import operator
import os
import re
import sys


# Command-line interface: the OpenSSL source tree is required, the output
# file is optional (main() prints to stdout when it is left out).
parser = argparse.ArgumentParser(
    description="Generate ssl_data.rs from OpenSSL sources"
)
parser.add_argument("srcdir", help="OpenSSL source directory")
parser.add_argument(
    "output",
    nargs="?",
    default=None,
    help="Rust file to generate (default: print to stdout)",
)


def _file_search(fname, pat):
    """Yield a match object for each line of *fname* in which *pat* is found.

    The file is read as UTF-8, one line at a time; lines without a match
    are skipped.
    """
    with open(fname, encoding="utf-8") as handle:
        candidates = map(pat.search, handle)
        yield from (hit for hit in candidates if hit is not None)


def parse_err_h(args):
    """Parse library codes from err.h, e.g. ERR_LIB_X509: 11.

    Reads ``args.err_h`` and returns a dict mapping the library name (the
    part after ``ERR_LIB_``) to its integer code.
    """
    # The trailing negative lookahead only accepts plain decimal values.
    # Without it, the leading digits of suffixed or hexadecimal values are
    # taken as a library code, e.g. the bit-layout macros
    # ``ERR_LIB_OFFSET 23L`` and ``ERR_LIB_MASK 0xFF`` found in OpenSSL 3.x
    # would be recorded as libraries 23 and 0.
    pat = re.compile(r"#\s*define\W+ERR_LIB_(\w+)\s+(\d+)(?!\w)")
    lib2errnum = {}
    for match in _file_search(args.err_h, pat):
        libname, num = match.groups()
        lib2errnum[libname] = int(num)

    return lib2errnum


def parse_openssl_error_text(args):
    """Yield error reasons such as X509_R_AKID_MISMATCH from ``args.errtxt``.

    Each item is a ``(reason, libname, errname, num)`` tuple, where *num* is
    the reason code converted to ``int``.
    """
    # Backslash line continuations are not handled for now; only the line
    # that starts an entry is inspected.
    reason_line = re.compile(r"^((\w+?)_R_(\w+)):(\d+):")
    for hit in _file_search(args.errtxt, reason_line):
        mnemonic, library, name, code = hit.groups()
        # Entries containing "_F_" are function codes, not reasons.
        if "_F_" not in mnemonic:
            yield mnemonic, library, name, int(code)


def parse_extra_reasons(args):
    """Yield the extra reasons declared on ``R`` lines of ``args.errcodes``.

    Each item is a ``(reason, libname, errname, num)`` tuple, where *num* is
    the reason code converted to ``int``.
    """
    extra_line = re.compile(r"^R\s+((\w+)_R_(\w+))\s+(\d+)")
    for hit in _file_search(args.errcodes, extra_line):
        *names, code = hit.groups()
        yield (*names, int(code))


def gen_library_codes_rust(args):
    """Yield the lines of the Rust ``LIBRARY_CODES`` phf map.

    ``args.lib2errnum`` maps library names to numeric codes; the generated
    map goes the other way, from code to name, with entries ordered by code.
    """
    yield "// Maps lib_code -> library name"
    yield '// Example: 20 -> "SSL"'
    yield "pub static LIBRARY_CODES: phf::Map<u32, &'static str> = phf_map! {"

    # Names are visited alphabetically, so when several names share one code
    # the alphabetically last name replaces the earlier ones.
    name_by_code = {
        code: name for name, code in sorted(args.lib2errnum.items())
    }

    yield from (
        f'    {code}u32 => "{name}",'
        for code, name in sorted(name_by_code.items())
    )
    yield "};"
    yield ""


def gen_error_codes_rust(args):
    """Yield the lines of the Rust ``ERROR_CODES`` phf map.

    ``args.reasons`` is an iterable of ``(reason, libname, errname, num)``
    tuples and ``args.lib2errnum`` maps library names to numeric codes.
    Reasons whose library is not in ``args.lib2errnum`` are skipped.  Each
    remaining reason is emitted under the key ``(lib << 32) | num``.
    """
    yield "// Maps encoded (lib, reason) -> error mnemonic"
    yield '// Example: encode_error_key(20, 134) -> "CERTIFICATE_VERIFY_FAILED"'
    yield "// Key encoding: (lib << 32) | reason"
    yield "pub static ERROR_CODES: phf::Map<u64, &'static str> = phf_map! {"

    # Deduplicate: phf_map! does not accept the same key twice, and the same
    # reason can be listed more than once in args.reasons.  As for the
    # library codes, the last entry wins; the dict keeps the position of the
    # first occurrence, so the output is unchanged when there are no
    # duplicates.
    seen = {}
    for _reason, libname, errname, num in args.reasons:
        lib_num = args.lib2errnum.get(libname)
        if lib_num is None:
            continue
        # Encode (lib, reason) as single u64
        seen[(lib_num << 32) | num] = errname

    for key, errname in seen.items():
        yield f'    {key}u64 => "{errname}",'
    yield "};"
    yield ""


def main():
    """Parse the command line and generate the Rust SSL data tables.

    Reads ``include/openssl/err.h`` (or ``err.h.in`` when ``err.h`` does not
    exist), ``crypto/err/openssl.ec`` and ``crypto/err/openssl.txt`` below
    the given source directory.  The generated Rust code is written to the
    output file as UTF-8, or printed to stdout when no output file was
    given.

    Exits through ``parser.error`` when one of the input files is missing.
    """
    args = parser.parse_args()

    args.err_h = os.path.join(args.srcdir, "include", "openssl", "err.h")
    if not os.path.isfile(args.err_h):
        # Fall back to infile for OpenSSL 3.0.0
        args.err_h += ".in"
    args.errcodes = os.path.join(args.srcdir, "crypto", "err", "openssl.ec")
    args.errtxt = os.path.join(args.srcdir, "crypto", "err", "openssl.txt")

    # Report every missing input as a usage error up front instead of
    # failing later with a traceback from open().
    for required in (args.errtxt, args.err_h, args.errcodes):
        if not os.path.isfile(required):
            parser.error(f"File {required} not found in srcdir.")

    # {X509: 11, ...}
    args.lib2errnum = parse_err_h(args)

    # [('X509_R_AKID_MISMATCH', 'X509', 'AKID_MISMATCH', 110), ...]
    reasons = []
    reasons.extend(parse_openssl_error_text(args))
    reasons.extend(parse_extra_reasons(args))
    # sort by full reason mnemonic (which starts with the library name),
    # then by numeric error code
    args.reasons = sorted(reasons, key=operator.itemgetter(0, 3))

    lines = [
        "// File generated by tools/make_ssl_data_rs.py",
        f"// Generated on {datetime.datetime.now(datetime.timezone.utc).isoformat()}",
        f"// Source: OpenSSL from {args.srcdir}",
        "// spell-checker: disable",
        "",
        "use phf::phf_map;",
        "",
    ]
    lines.extend(gen_library_codes_rust(args))
    lines.extend(gen_error_codes_rust(args))

    # Add helper function
    lines.extend(
        [
            "/// Helper function to create encoded key from (lib, reason) pair",
            "#[inline]",
            "pub fn encode_error_key(lib: i32, reason: i32) -> u64 {",
            "    ((lib as u64) << 32) | (reason as u64 & 0xFFFFFFFF)",
            "}",
            "",
        ]
    )

    if args.output is None:
        for line in lines:
            print(line)
    else:
        # Rust source files must be UTF-8, and the header embeds the srcdir
        # path verbatim, so do not rely on the platform default encoding.
        with open(args.output, "w", encoding="utf-8") as output:
            for line in lines:
                print(line, file=output)

        print(f"Generated {args.output}")
        print(f"Found {len(args.lib2errnum)} library codes")
        print(f"Found {len(args.reasons)} error codes")


# Run the generator only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
